from typing import Any, Optional, Union
import warnings
import numpy as np

from scipy.optimize import linprog
from scipy.spatial import ConvexHull

from RACH_Space_model import RACH_Space_model

from ..basemodel import BaseLabelModel
from ..dataset import BaseDataset
from ..dataset.utils import check_weak_labels

############utilities###########
############utilities###########
############utilities###########

def extreme_points(matrix):
    """Return the columns of ``matrix`` that lie on the convex hull of its columns.

    Every column of ``matrix`` is treated as one point. The hull is built
    with ``scipy.spatial.ConvexHull`` and the selected columns are returned
    in ascending order of their original column index.

    Args:
        matrix: 2-D array whose columns are the candidate points.

    Returns:
        2-D array with the same number of rows as ``matrix`` holding only
        the columns referenced by at least one hull facet.
    """
    # ConvexHull expects one point per row, so operate on the transpose.
    points = matrix.T
    facets = ConvexHull(points).simplices

    # np.unique flattens the facet table, drops duplicates and sorts.
    vertex_ids = np.unique(facets)

    # Restore the column-per-point layout of the input.
    return points[vertex_ids].T

def remainder_columns(original_matrix, subset_matrix):
    """Return the columns of ``original_matrix`` that do not occur in ``subset_matrix``.

    A column counts as present when it equals some column of
    ``subset_matrix`` element for element. The surviving columns keep their
    original relative order.

    Args:
        original_matrix: 2-D array whose columns are filtered.
        subset_matrix: 2-D array (same number of rows) of columns to drop.

    Returns:
        2-D array made of the columns of ``original_matrix`` that have no
        exact match in ``subset_matrix``.
    """
    kept = []
    for col_idx in range(original_matrix.shape[1]):
        # Keep the column 2-D so broadcasting compares it against every
        # column of the subset at once.
        column = original_matrix[:, col_idx:col_idx + 1]
        matches = np.all(column == subset_matrix, axis=0)
        if not np.any(matches):
            kept.append(col_idx)

    return original_matrix[:, kept]

def in_hull(matrix, column_vector):
    """Test whether a point is a convex combination of the columns of ``matrix``.

    The test is a feasibility linear program: find weights ``x`` with
    ``matrix @ x == point`` and ``sum(x) == 1``. Non-negativity of ``x``
    comes from ``linprog``'s default variable bounds. The objective is all
    zeros because only feasibility matters.

    Args:
        matrix: 2-D array whose columns are the points spanning the hull.
        column_vector: The point to test. A column vector, a row vector or
            a flat 1-D array are all accepted; it is flattened before use.

    Returns:
        True when the solver reports success (a feasible weighting exists),
        False otherwise. An empty set of points has an empty hull, so a
        matrix with zero columns always yields False.
    """
    n_points = matrix.shape[1]
    if n_points == 0:
        # Nothing to combine: no point can be in the hull of an empty set.
        return False

    # Stack the "weights sum to one" row under the point coordinates.
    A_eq = np.vstack([matrix, np.ones((1, n_points))])
    b_eq = np.append(np.ravel(column_vector), 1.0)

    lp = linprog(np.zeros(n_points), A_eq=A_eq, b_eq=b_eq)
    return bool(lp.success)

def compute_A_matrix(weak_probabilities):
    """Build the error-constraint coefficient tensor: twice the weak signals.

    Args:
        weak_probabilities: 3-D array; its three axis lengths are read as
            ``(m, n, k)``.

    Returns:
        A new float array of shape ``(m, n, k)`` equal to
        ``2 * weak_probabilities``.
    """
    m, n, k = weak_probabilities.shape

    # Writing into a default-dtype buffer keeps the result float64 whatever
    # the input dtype is.
    doubled = np.empty((m, n, k))
    doubled[:] = 2 * weak_probabilities
    return doubled

def compute_b_vector(weak_probabilities, error_bounds):
    """Compute the right-hand side of the error upper-bound constraints.

    For each of the ``m`` weak signals, all ``n * k`` entries are summed and
    divided by ``n``; the result is combined with the error bounds as
    ``-error_bounds + sums / n + 1``.

    Args:
        weak_probabilities: 3-D array; its axis lengths are read as
            ``(m, n, k)``.
        error_bounds: Array with the same number of dimensions as the
            per-signal sums (which have shape ``(m, 1)``).

    Returns:
        Array of bounds, broadcast from ``error_bounds`` and the ``(m, 1)``
        per-signal sums.
    """
    m, n, k = weak_probabilities.shape

    # One (n*k, 1) column per weak signal.
    per_signal = list(np.reshape(weak_probabilities, (m, n * k, 1)))

    # Collapse each column to a single value, scaled by the number of examples.
    scaled_sums = np.sum(per_signal, axis=1) / n

    assert scaled_sums.ndim == error_bounds.ndim

    return -error_bounds + scaled_sums + 1

def build_constraints(a_matrix, bounds):
    """Bundle a coefficient array and its bounds into a constraint dictionary.

    Args:
        a_matrix: Coefficient array, stored under ``'A'`` as-is.
        bounds: Bounds array, stored under ``'b'`` as-is.

    Returns:
        Dict with keys ``'A'``, ``'b'``, ``'gamma'`` (zeros shaped like
        ``bounds``) and ``'c'`` (zeros shaped like ``a_matrix``).
    """
    return {
        'A': a_matrix,
        'b': bounds,
        'gamma': np.zeros(bounds.shape),
        # Placeholder for now: all-zero array matching the coefficients.
        'c': np.zeros(a_matrix.shape),
    }

def set_up_constraint(weak_probabilities, error_bounds):
    """Assemble the constraint set for the given weak signals and error bounds.

    Args:
        weak_probabilities: Weak-signal array passed to ``compute_A_matrix``
            and ``compute_b_vector``.
        error_bounds: Error bounds passed to ``compute_b_vector``.

    Returns:
        Dict with a single key ``'error'`` whose value is the dictionary
        produced by ``build_constraints``.
    """
    coefficients = compute_A_matrix(weak_probabilities)
    rhs = compute_b_vector(weak_probabilities, error_bounds)

    return {'error': build_constraints(coefficients, rhs)}

############utilities###########
############utilities###########
############utilities###########


def nonabstain_data(weak_signals_probas):
    """Replace abstain entries (negative values) with the uniform value ``1/k``.

    ``k`` is the length of the last of the array's three axes.

    Args:
        weak_signals_probas: 3-D array of weak signals in which negative
            entries mark abstentions.

    Returns:
        The array with every negative entry set to ``1/k``. A
        non-integer array is modified in place and returned. An integer
        array cannot store ``1/k`` (assignment would truncate it to 0), so
        it is converted to a new float array and the input is left untouched.
    """
    _, _, k = weak_signals_probas.shape

    if np.issubdtype(weak_signals_probas.dtype, np.integer):
        # Upcast first; otherwise 1/k would silently be stored as 0.
        weak_signals_probas = weak_signals_probas.astype(float)

    weak_signals_probas[weak_signals_probas < 0] = 1 / k

    return weak_signals_probas

def limit_rows(arr, method="mean_chunks", max_rows=5):
    """Reduce the leading axis of ``arr`` to at most ``max_rows`` entries.

    Arrays that already have ``max_rows`` or fewer rows are returned
    unchanged (the same object). Otherwise the rows are split into
    ``max_rows`` contiguous chunks of near-equal size, with the larger
    chunks first, and each chunk is replaced by its mean along axis 0.

    Args:
        arr: Array whose first axis is to be reduced.
        method: Reduction strategy. Only ``"mean_chunks"`` is implemented.
        max_rows: Maximum number of rows in the result. Defaults to 5.

    Returns:
        ``arr`` itself when it is small enough, otherwise a new array with
        ``max_rows`` rows.

    Raises:
        ValueError: If ``max_rows`` is smaller than 1, or if a reduction is
            needed and ``method`` is not a supported strategy.
    """
    if max_rows < 1:
        raise ValueError("max_rows must be at least 1, got {!r}".format(max_rows))

    if arr.shape[0] <= max_rows:
        return arr

    if method != "mean_chunks":
        # Fail loudly instead of silently returning None to the caller.
        raise ValueError(
            "Unknown method {!r}; only 'mean_chunks' is supported.".format(method)
        )

    # np.array_split gives the first (rows % max_rows) chunks one extra row.
    chunks = np.array_split(arr, max_rows)
    return np.array([np.mean(chunk, axis=0) for chunk in chunks])

def convert_weak_signals_format_wrench_to_cll(L, n_class):
    """Expand a label matrix into a per-signal one-hot tensor.

    Args:
        L: 2-D array read as ``(n_examples, n_weak)``. Each entry is a class
            index, or ``-1`` for an abstention.
        n_class: Length of the class axis in the result.

    Returns:
        Float array ``W`` of shape ``(n_weak, n_examples, n_class)``.
        ``W[w, i]`` is all ``-1`` when ``L[i, w] == -1``; otherwise it is
        zero except for a ``1`` at position ``L[i, w]``.
    """
    n_examples, n_weak = L.shape
    W = np.zeros((n_weak, n_examples, n_class))

    # Note the axis swap: L is (example, signal) but W is (signal, example).
    for (example, signal), label in np.ndenumerate(L):
        if label == -1:
            W[signal, example, :] = -1
        else:
            W[signal, example, label] = 1

    return W

def replace_negatives_with_fraction(L, k):
    """Replace negative values in ``L`` with ``1/k``.

    Args:
        L: Array in which negative entries are to be replaced.
        k: Denominator of the replacement value.

    Returns:
        The array with every negative entry set to ``1/k``. A non-integer
        array is modified in place and returned. An integer array cannot
        store ``1/k`` (assignment would truncate it to 0), so it is
        converted to a new float array and the input is left untouched.
    """
    if np.issubdtype(L.dtype, np.integer):
        # Upcast first; otherwise 1/k would silently be stored as 0.
        L = L.astype(float)

    L[L < 0] = 1 / k
    return L

class RACH_Space(BaseLabelModel):
    """Label model that builds RACH-Space constraints and delegates to a solver.

    ``fit`` turns a weak-label matrix into the constraint set (``A``, ``b``)
    and hands it to the wrapped model; ``predict_proba`` forwards the
    abstain-filled label matrix to the wrapped model.
    """

    def __init__(self, model_class: RACH_Space_model, **kwargs: Any):
        """Instantiate the wrapped model.

        Args:
            model_class: Callable (the model class) used to build the
                wrapped model.
            **kwargs: Forwarded unchanged to ``model_class``.
        """
        super().__init__()
        self.model = model_class(**kwargs)
        # Class count used by the last fit(); lets predict_proba work on a
        # bare array, which carries no class count of its own.
        self._n_class: Optional[int] = None

    def fit(self,
            dataset_train: Union[BaseDataset, np.ndarray],
            n_class: Optional[int] = None):
        """Build the constraint set from the weak labels and fit the wrapped model.

        Progress and intermediate matrices are printed to stdout.

        Args:
            dataset_train: Dataset or weak-label matrix, as accepted by
                ``check_weak_labels``.
            n_class: Number of classes. For a ``BaseDataset`` it must match
                ``dataset_train.n_class`` when given. For a bare array it
                defaults to the largest label plus one.
        """
        # get weak labels
        L = check_weak_labels(dataset_train)

        print("This is the weak_labels for RACH-Space:", L)
        print("This is the shape of weak_labels for RACH-Space", L.shape)

        if isinstance(dataset_train, BaseDataset):
            if n_class is not None:
                assert n_class == dataset_train.n_class
            else:
                n_class = dataset_train.n_class
        elif n_class is None:
            # Labels are used as class indices, so the class axis must be
            # at least max(label) + 1 long.
            n_class = int(np.max(L)) + 1

        self._n_class = n_class

        weak_signals_probas = convert_weak_signals_format_wrench_to_cll(L, n_class)
        weak_signals_probas = nonabstain_data(weak_signals_probas)

        # Limit the number of weak signals to at most 5 (the cap applied by
        # limit_rows) by averaging contiguous chunks of signals.
        weak_signals_probas = limit_rows(weak_signals_probas, method="mean_chunks")

        m, n, k = weak_signals_probas.shape

        # Baseline per-signal error before any shift is applied.
        base_error = 2 / k - 2 / (k ** 2)

        weak_errors = np.ones((m, 1)) * base_error * k
        constraint_set = set_up_constraint(weak_signals_probas, weak_errors)['error']

        A = np.reshape(constraint_set['A'], (m, n * k))
        b = constraint_set['b']

        print("Initial A matrix : ", A)
        print("Initial b/n vector : ", b)

        # H1: columns of A on the convex hull; H2: all remaining columns.
        H1 = extreme_points(A)
        H2 = remainder_columns(A, H1)

        print("Shape of H1 :", H1.shape)
        print("Shape of H2 :", H2.shape)

        # Push b out of Conv(H2) in case it starts inside. The step is only
        # advanced while b is still inside, so on exit weak_errors,
        # constraint_set and b all belong to the step that was actually
        # tested, and a b that starts outside is left unshifted.
        step_size = 0
        while True:
            weak_errors = np.ones((m, 1)) * (base_error - step_size) * k
            constraint_set = set_up_constraint(weak_signals_probas, weak_errors)['error']
            b = constraint_set['b']
            if not in_hull(H2, b):
                break
            step_size += 0.01

        print("used weak error size after update so b is out of Conv(H2):", base_error - step_size)

        print("Updated b/n vector :", b)
        print("Is b/n in Conv(H2)?", in_hull(H2, b))
        print("Is b/n in Conv(H1)?", in_hull(H1, b))

        # Scale b by the number of data points n.
        print("dimension of b", b.shape)
        constraint_set['b'] = b * n
        print("dimension of constraint_set['b']", constraint_set['b'].shape)

        print("A used for RACH-Space:", np.reshape(constraint_set['A'], (m, n * k)))
        print("Dimension of A: ", constraint_set['A'].shape)

        print("b used for RACH-Space:", constraint_set['b'])

        # Append an all-ones (1, n, k) slab to A and the matching entry n to b.
        ones_row = np.ones((1, n, k))
        constraint_set['A'] = np.concatenate((constraint_set['A'], ones_row), axis=0)
        constraint_set['b'] = np.vstack([constraint_set['b'], [[n]]])

        # Work on a copy: L may be the caller's own array (TODO confirm
        # against check_weak_labels) and the helper can write in place.
        L = replace_negatives_with_fraction(L.copy(), k)

        self.model.fit(L, weak_signals_probas, constraint_set)

    def predict_proba(self, dataset: Union[BaseDataset, np.ndarray], **kwargs: Any) -> np.ndarray:
        """Return the wrapped model's class probabilities for ``dataset``.

        Abstain entries (negative labels) are replaced by ``1/k`` before the
        label matrix is handed to the wrapped model, where ``k`` is the
        number of classes.

        Args:
            dataset: Dataset or weak-label matrix, as accepted by
                ``check_weak_labels``.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            Whatever ``self.model.predict_proba`` returns for the processed
            label matrix.
        """
        L = check_weak_labels(dataset)

        if isinstance(dataset, BaseDataset):
            n_class = dataset.n_class
        elif self._n_class is not None:
            # Bare array: reuse the class count from fit().
            n_class = self._n_class
        else:
            # Never fitted with a known class count; fall back to the labels.
            n_class = int(np.max(L)) + 1

        # The class axis of the weak-signal tensor has length n_class, so
        # k is known without rebuilding that tensor.
        k = n_class

        # Copy for the same aliasing reason as in fit().
        L = replace_negatives_with_fraction(L.copy(), k)
        return self.model.predict_proba(L)


class RACH_Space_Algorithm(RACH_Space):
    """``RACH_Space`` with the wrapped model fixed to ``RACH_Space_model``."""

    def __init__(self, **kwargs):
        """Build the label model; ``kwargs`` are forwarded to ``RACH_Space.__init__``."""
        super().__init__(model_class=RACH_Space_model, **kwargs)

